{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Writing a Dataset to AIS in WebDataset Format\n",
    "\n",
    "In this notebook we will download and store the following datasets in [WebDataset](https://github.com/webdataset/webdataset) format in AIS:\n",
    "\n",
    "- [The Oxford-IIIT Pet Dataset](https://academictorrents.com/details/b18bbd9ba03d50b0f7f479acc9f4228a408cecc1)\n",
    "- [Flickr Image dataset](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the running kernel\n",
    "%pip install aistore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Up Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from aistore.client import Client\n",
    "\n",
    "# Take the cluster endpoint from AIS_ENDPOINT, falling back to http://localhost:8080\n",
    "ais_url = os.environ.get(\"AIS_ENDPOINT\", \"http://localhost:8080\")\n",
    "client = Client(ais_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Oxford-IIIT Pet Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import tarfile\n",
    "import os\n",
    "\n",
    "\n",
    "def download_and_extract(url, dest_path):\n",
    "    \"\"\"\n",
    "    Download the tar archive at `url`, unpack it next to `dest_path`, then delete the archive.\n",
    "\n",
    "    Args:\n",
    "        url (str): Location of the tar archive to fetch\n",
    "        dest_path (str): Local path where the archive is stored temporarily; the\n",
    "            extracted contents go into its parent directory\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If the server answers with an error status\n",
    "    \"\"\"\n",
    "    with requests.get(url, stream=True, timeout=60) as response:\n",
    "        # Fail loudly instead of silently skipping the download on a bad status\n",
    "        response.raise_for_status()\n",
    "        with open(dest_path, \"wb\") as f:\n",
    "            # Write in chunks so a large archive never has to fit in memory at once\n",
    "            for chunk in response.iter_content(chunk_size=1024 * 1024):\n",
    "                f.write(chunk)\n",
    "\n",
    "    # Use the safe \"data\" extraction filter where this Python version provides it\n",
    "    extract_kwargs = {\"filter\": \"data\"} if hasattr(tarfile, \"data_filter\") else {}\n",
    "    try:\n",
    "        with tarfile.open(dest_path) as tar:\n",
    "            tar.extractall(path=os.path.dirname(dest_path), **extract_kwargs)\n",
    "    finally:\n",
    "        os.remove(dest_path)  # Clean up the tar file even if extraction fails"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch over HTTPS: the archives are extracted locally, so they should be protected in transit\n",
    "base_url = \"https://www.robots.ox.ac.uk/~vgg/data/pets/data\"\n",
    "images_url = f\"{base_url}/images.tar.gz\"\n",
    "annotations_url = f\"{base_url}/annotations.tar.gz\"\n",
    "\n",
    "data_dir = \"/data\"\n",
    "images_path = os.path.join(data_dir, \"images.tar.gz\")\n",
    "annotations_path = os.path.join(data_dir, \"annotations.tar.gz\")\n",
    "\n",
    "# exist_ok avoids the check-then-create race and keeps the cell safe to re-run\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "download_and_extract(images_url, images_path)\n",
    "download_and_extract(annotations_url, annotations_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a bucket and writing the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from aistore.sdk.dataset.dataset_config import DatasetConfig\n",
    "from aistore.sdk.dataset.data_attribute import DataAttribute\n",
    "from aistore.sdk.dataset.label_attribute import LabelAttribute\n",
    "\n",
    "# Root directory the dataset paths below are resolved against\n",
    "base_path = Path(\"/data\")\n",
    "# Destination bucket; exist_ok=True is passed so an existing bucket is not treated as an error\n",
    "bucket = client.bucket(\"pets-dataset\").create(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to get label from the annotation file\n",
    "\n",
    "\n",
    "def get_class_dict(path: Path):\n",
    "    parsed_dict = {}\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as file:\n",
    "        for line in file.readlines():\n",
    "            if line[0] == \"#\":\n",
    "                continue\n",
    "            file_name, label = line.split(\" \")[:2]\n",
    "            parsed_dict[file_name] = label\n",
    "\n",
    "    return parsed_dict\n",
    "\n",
    "\n",
    "# Parse the labels once; the lookup function below reads from this mapping\n",
    "parsed_dict = get_class_dict(base_path / \"annotations\" / \"list.txt\")\n",
    "\n",
    "\n",
    "def get_label_for_filename(filename):\n",
    "    \"\"\"Return the label stored for `filename`, or None when there is no entry for it.\"\"\"\n",
    "    return parsed_dict.get(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images_dir = base_path / \"images\"\n",
    "trimaps_dir = base_path / \"annotations\" / \"trimaps\"\n",
    "\n",
    "# jpg images are the primary attribute; png trimaps and the class label are secondary\n",
    "dataset_config = DatasetConfig(\n",
    "    primary_attribute=DataAttribute(name=\"image\", path=images_dir, file_type=\"jpg\"),\n",
    "    secondary_attributes=[\n",
    "        DataAttribute(name=\"trimap\", path=trimaps_dir, file_type=\"png\"),\n",
    "        LabelAttribute(name=\"cls\", label_identifier=get_label_for_filename),\n",
    "    ],\n",
    ")\n",
    "\n",
    "bucket.write_dataset(config=dataset_config, pattern=\"img_dataset\", maxcount=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Flickr Image dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**NOTE:** We are using the [Kaggle API](https://github.com/Kaggle/kaggle-api/blob/main/docs/README.md) to download the dataset. Make sure your Kaggle API credentials are configured before running the cells below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the running kernel\n",
    "%pip install kaggle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the hsankesara/flickr-image-dataset archive into /data and unzip it\n",
    "!kaggle datasets download -d hsankesara/flickr-image-dataset -p /data --unzip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a bucket and writing the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from aistore.sdk.dataset.dataset_config import DatasetConfig\n",
    "from aistore.sdk.dataset.data_attribute import DataAttribute\n",
    "from aistore.sdk.dataset.label_attribute import LabelAttribute\n",
    "\n",
    "# Root directory the dataset paths below are resolved against\n",
    "base_path = Path(\"/data\")\n",
    "# Destination bucket; exist_ok=True is passed so an existing bucket is not treated as an error\n",
    "bucket = client.bucket(\"flickr-dataset\").create(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to get the caption from results.csv\n",
    "def parse_csv(path: Path):\n",
    "    parsed_dict = {}\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as file:\n",
    "        for line in file:\n",
    "            splitted = line.split(\"|\")\n",
    "            if len(splitted) < 3:\n",
    "                continue\n",
    "            filename = splitted[0].strip().split(\".\")[0]\n",
    "            caption = splitted[2].strip()\n",
    "            parsed_dict[filename] = caption\n",
    "    return parsed_dict\n",
    "\n",
    "\n",
    "# Kept under its own name so the mapping read by get_label_for_filename is not overwritten\n",
    "flickr_captions = parse_csv(base_path.joinpath(\"flickr30k_images/results.csv\"))\n",
    "\n",
    "\n",
    "def get_caption_for_filename(filename):\n",
    "    \"\"\"Return the caption stored for `filename`, or None when there is no entry for it.\"\"\"\n",
    "    return flickr_captions.get(filename, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "flickr_images_dir = base_path / \"flickr30k_images\" / \"flickr30k_images\"\n",
    "\n",
    "# jpg images are the primary attribute; the caption label is the only secondary attribute\n",
    "dataset_config = DatasetConfig(\n",
    "    primary_attribute=DataAttribute(\n",
    "        name=\"image\", path=flickr_images_dir, file_type=\"jpg\"\n",
    "    ),\n",
    "    secondary_attributes=[\n",
    "        LabelAttribute(name=\"caption\", label_identifier=get_caption_for_filename),\n",
    "    ],\n",
    ")\n",
    "\n",
    "bucket.write_dataset(config=dataset_config, pattern=\"flickr_dataset\", maxcount=1000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "my-python3-kernel",
   "language": "python",
   "name": "my-python3-kernel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
